
# RNN: character-level sequence demo — train a vanilla RNN to map "hello" -> "ohlol"


import torch
import torch.nn as nn

# Dimensions for the toy sequence task.
batch_size = 1
input_size = 4
input_seq = 5
hidden_size = 4

# Vocabulary: index -> character.
idx2char = ['e', 'h', 'l', 'o']

# Input "hello" and target "ohlol", both as character indices.
x_data = [1, 0, 2, 2, 3]
y_data = [3, 1, 2, 3, 2]

# 4x4 identity rows serve as one-hot encodings for the four characters.
one_hot_lookup = [[1 if col == row else 0 for col in range(4)] for row in range(4)]

# Encode each input index as its one-hot vector.
x_one_hot = [one_hot_lookup[idx] for idx in x_data]

print(f"x_one_hot: {x_one_hot}")

# (seq_len, batch, input_size) float tensor — the layout nn.RNN expects.
inputs = torch.tensor(x_one_hot, dtype=torch.float32).view(-1, batch_size, input_size)
print(f"inputs :{inputs}")
# Class-index targets for CrossEntropyLoss.
labels = torch.tensor(y_data, dtype=torch.long)
print(f"labels :{labels}")

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


class Model(nn.Module):
    """Single-layer vanilla RNN mapping each one-hot input step to
    hidden_size logits (one score per character class).

    Args:
        input_size: size of each one-hot input vector.
        hidden_size: RNN hidden dimension (doubles as number of classes here).
        batch_size: batch dimension used to shape the initial hidden state.
    """

    def __init__(self, input_size, hidden_size, batch_size):
        super(Model, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size

        # nn.RNN default layout: inputs are (seq_len, batch, input_size).
        self.rnn = nn.RNN(input_size, hidden_size, num_layers=1)

    def forward(self, inputs):
        # Create the initial hidden state on the same device as the input.
        # The original read the module-level global `device`, which raises
        # NameError if the class is imported elsewhere and silently breaks
        # when model and input live on a different device than the global.
        h0 = torch.zeros(1, self.batch_size, self.hidden_size,
                         device=inputs.device)
        out, _ = self.rnn(inputs, h0)
        # Flatten (seq_len, batch, hidden) -> (seq_len * batch, hidden) so
        # CrossEntropyLoss sees one row of logits per time step.
        return out.view(-1, self.hidden_size)


# Instantiate the model and place it on the selected device.
net = Model(input_size, hidden_size, batch_size).to(device)
print(net.parameters())

# Per-time-step classification loss over the four character classes.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

epochs = 15
for ep in range(epochs):
    # Clear stale gradients before this step's backward pass.
    optimizer.zero_grad()

    # Forward pass yields (seq_len * batch, hidden_size) logits.
    logits = net(inputs.to(device))
    loss = criterion(logits, labels.to(device))

    # Backpropagate and update parameters.
    loss.backward()
    optimizer.step()

    predicted = logits.max(dim=1)[1]
    print(f'Pred: {predicted}  labels: {labels}    Epoch [{ep + 1}/{epochs}]    loss = {loss.item():.4f}')